import numpy as np
import pandas as pd
from scipy.stats import truncnorm
from scipy.optimize import minimize
from functools import partial
from typing import Dict, List, Tuple
import sys, time, random, os

# --- Environment Setup (Keep outside the loop)

# Project working directory; run_simulation writes its results to ./output
# relative to the current working directory.
new_directory = 'C:\\Users\\xt9\\Box\\Python Codes'
if os.path.isdir(new_directory):
    os.chdir(new_directory)
else:
    # An unconditional os.chdir here made importing this module crash on any
    # machine without this path; fall back to the current directory instead.
    print(f"Warning: {new_directory!r} not found; using current directory {os.getcwd()!r}.")

# Seed both RNGs once at import time. Every scenario run afterwards draws from
# this single global stream, so results depend on the order of scenarios.
SEED_VALUE = 42
random.seed(SEED_VALUE)
np.random.seed(SEED_VALUE)

# --- Function Definitions (Keep outside the loop)

# calculate link formation probability
def p(xi, xj, design, prob):
    """Return the matrix of link-formation probabilities p(x_i, x_j).

    The probability depends only on the unordered pair {x_i, x_j}, so (a, b)
    and (b, a) share one entry of ``prob``.

    Parameters
    ----------
    xi : np.ndarray
        Column vector, shape (n, 1), of the first node's covariate.
    xj : np.ndarray
        Row vector, shape (1, n), of the second node's covariate.
    design : int
        ``11`` or ``12``: support {0, 1}, ``prob = [p(0,0), p(0,1), p(1,1)]``.
        ``2``: support {0, 1, 2},
        ``prob = [p(0,0), p(0,1), p(0,2), p(1,1), p(1,2), p(2,2)]``.
    prob : sequence of float
        Link probabilities in the order described for ``design``.

    Returns
    -------
    np.ndarray
        Array of shape (xi.shape[0], xj.shape[1]); pairs with a value outside
        the support get probability 0.

    Raises
    ------
    ValueError
        If ``design`` is not 11, 12 or 2 (this used to surface as an
        UnboundLocalError).
    """
    if design in (11, 12):
        support = (0, 1)
    elif design == 2:
        support = (0, 1, 2)
    else:
        raise ValueError(f"Unsupported design {design!r}; expected 11, 12 or 2.")

    p_matrix = np.zeros((xi.shape[0], xj.shape[1]), dtype=float)
    # Walk the unordered pairs (a <= b) in exactly the order used by ``prob``:
    # (0,0), (0,1), (1,1) for two values; (0,0), (0,1), (0,2), (1,1), (1,2),
    # (2,2) for three.
    k = 0
    for pos, a in enumerate(support):
        for b in support[pos:]:
            mask = ((xi == a) & (xj == b)) | ((xi == b) & (xj == a))
            p_matrix[mask] = prob[k]
            k += 1
    return p_matrix

# generate X, G, M, given n, S, Lm, design
def generate_XMG(n: int, S: int, Lm: int, design: int, Sx: np.ndarray, prob: np.ndarray):
    """Simulate ``S`` independent networks of ``n`` nodes and their statistic m.

    For each draw s:
      * x_i is drawn uniformly with replacement from ``Sx``;
      * each ordered pair i -> j (i != j) links independently with probability
        ``p(x_i, x_j, design, prob)``. (i, j) and (j, i) are drawn separately,
        so G need not be symmetric. Self-links are excluded;
      * m_raw_i is the average |x_i - x_j| over i's out-neighbours (0.0 when i
        has none), and m_i is m_raw_i discretized into ``Lm`` equal-width bins.

    Returns
    -------
    X : (n, S) array of covariates.
    G : (n, n, S) bool array; G[i, j, s] is True when i links to j in draw s.
    M : (n, S) array of discretized m_i (every value is an element of ``Sm``).
    M_raw : (n, S) array of m_i before discretization.
    Sm : (Lm,) array, the support of the discretized m.
    """
    X = np.zeros((n, S))
    G = np.zeros((n, n, S), dtype=bool)  # boolean for memory/speed
    M = np.zeros((n, S))
    M_raw = np.zeros((n, S))

    # The bins do not depend on the draw, so build them once (this also keeps
    # ``Sm`` defined when S == 0). An average of |x_i - x_j| lies in
    # [0, max(Sx) - min(Sx)]; for supports starting at 0 this equals the
    # previous [Sx[0], Sx[-1]].
    m_bins = np.linspace(0.0, np.max(Sx) - np.min(Sx), Lm + 1)
    # NOTE(review): these are 3x the bin midpoints, not the midpoints
    # themselves -- confirm the scaling is intended.
    Sm = 3 * (m_bins[:-1] + m_bins[1:]) / 2

    for s in range(S):
        x = np.random.choice(Sx, size=n)
        x_col = x[:, np.newaxis]
        x_row = x[np.newaxis, :]
        link_prob = p(x_col, x_row, design, prob)
        g = np.random.rand(n, n) < link_prob
        # A node is not its own neighbour. A self-link would add 0 to the
        # difference sum but still count in the degree, biasing m downward.
        np.fill_diagonal(g, False)

        # Average |x_i - x_j| over out-neighbours; isolated nodes get 0.0.
        sum_abs_diff = np.sum(np.abs(x_col - x_row) * g, axis=1)
        degree = np.sum(g, axis=1)
        m_raw = np.divide(sum_abs_diff, degree, out=np.zeros(n), where=degree != 0)

        # Bin index in [0, Lm - 1]; a value on the top edge goes to the last bin.
        m_idx = np.clip(np.digitize(m_raw, m_bins) - 1, 0, Lm - 1)

        X[:, s] = x
        G[:, :, s] = g
        M[:, s] = Sm[m_idx]
        M_raw[:, s] = m_raw

    return X, G, M, M_raw, Sm

# calculate qn(jm, im, ix, jx) using simulated draws
def _support_index(values, support):
    """Return ``(index, found)``: position of each value in ``support`` by exact equality."""
    matches = values[:, np.newaxis] == support[np.newaxis, :]
    return matches.argmax(axis=1), matches.any(axis=1)


def calculate_qn(X,G,M,Sx,Sm):
    """Estimate the neighbour transition probabilities q_n from simulated networks.

    ``qn[jm, im, ix, jx]`` is computed over all links i -> j, pooled across
    networks, whose sender has (x_i, m_i) == (Sx[ix], Sm[im]) and whose
    receiver has x_j == Sx[jx]. It is the share of those links with
    m_j == Sm[jm]. A receiver whose m_j matches no element of ``Sm`` counts in
    the denominator only. Cells with no qualifying link stay 0.

    Parameters
    ----------
    X, M : (n, S2) arrays of covariates and discretized m, one column per network.
    G : (n, n, S2) adjacency arrays; entry == 1 (or True) marks a link i -> j.
    Sx, Sm : supports of x and m. Values are matched by exact equality and
        are assumed distinct.

    Returns
    -------
    np.ndarray of shape (Lm, Lm, Lx, Lx).
    """
    Lm = len(Sm)
    Lx = len(Sx)
    Sx = np.asarray(Sx)
    Sm = np.asarray(Sm)
    hits = np.zeros((Lm, Lm, Lx, Lx))  # links counted by (jm, im, ix, jx)
    totals = np.zeros((Lm, Lx, Lx))    # links counted by (im, ix, jx)

    # One pass per network that bins every link. This replaces rescanning all
    # networks once per (ix, jx, im) cell with Python loops over nodes.
    for s in range(X.shape[1]):
        x, g, m = X[:, s], G[:, :, s], M[:, s]
        ix_of, x_ok = _support_index(x, Sx)
        im_of, m_ok = _support_index(m, Sm)
        senders, receivers = np.nonzero(g == 1)

        keep = x_ok[senders] & m_ok[senders] & x_ok[receivers]
        snd, rcv = senders[keep], receivers[keep]
        np.add.at(totals, (im_of[snd], ix_of[snd], ix_of[rcv]), 1)

        on_support = m_ok[rcv]
        snd, rcv = snd[on_support], rcv[on_support]
        np.add.at(hits, (im_of[rcv], im_of[snd], ix_of[snd], ix_of[rcv]), 1)

    qn = np.zeros_like(hits)
    denom = totals[np.newaxis]
    np.divide(hits, denom, out=qn, where=denom > 0)
    return qn

# estimate lambda using sample data
def estimate_lambda(m, x, a, Sx, Sm):
    """Return the cell means of ``a`` on the (x, m) support grid.

    Entry [r, c] is the average of ``a`` over observations with
    x == Sx[r] and m == Sm[c]; cells with no observation are left at 0.
    """
    lam = np.zeros((len(Sx), len(Sm)))
    for row, x_val in enumerate(Sx):
        for col, m_val in enumerate(Sm):
            in_cell = (x == x_val) & (m == m_val)
            cell_size = np.sum(in_cell)
            if cell_size == 0:
                continue
            lam[row, col] = np.sum(a * in_cell) / cell_size
    return lam

# define objective function for NLS
def Gn(gamma, Sx, Sm, Ix, Im, lam, Sigma, Lx):
    """NLS objective for ``gamma = (beta, delta, tau)``.

    For each observation i the model value is
    ``Sx[Ix[i]] * beta + Sm[Im[i]] * delta + tau * Sigma[i] / Lx``; the
    function returns the mean squared gap between ``lam[Ix[i], Im[i]]`` and
    that value.
    """
    b, d, t = gamma
    residual = lam[Ix, Im] - Sx[Ix] * b - Sm[Im] * d - t * Sigma/Lx
    return np.mean(residual**2)

# -------------------------------------------------------------------------
#                        Main Simulation Function
# -------------------------------------------------------------------------

def _structural_params(case: int):
    """Return ``(beta, delta, tau, d_epsilon)`` for parameter case 1, 2 or 3.

    ``d_epsilon`` is a frozen ``scipy.stats.truncnorm`` with loc 0. scipy takes
    the truncation bounds ``a, b`` in standardized units, (bound - loc) / scale.
    """
    if case == 3:
        beta, delta, tau = 1.5, 0.8, 0.7
        a, b, scale = -1.5, 1.5, 1
    elif case == 1:
        beta, delta, tau = 3.0, 1.5, 0.8
        a, b, scale = -1.5, 1.5, 1
    elif case == 2:
        beta, delta, tau = 3.0, 1.5, 0.8
        # NOTE(review): the bounds are not divided by scale=1.5, so epsilon is
        # truncated at +/-3.0 (2 sd) rather than +/-2.0 -- confirm intent.
        a, b, scale = -2.0, 2.0, 1.5
    else:
        raise ValueError(f"Unsupported case {case!r}; expected 1, 2 or 3.")
    return beta, delta, tau, truncnorm(a=a, b=b, loc=0, scale=scale)


def _design_spec(design: int, convergent: int, n: int):
    """Return ``(Sx, prob)``: the support of x and the link probabilities for ``p``.

    With ``convergent == 0`` the probabilities are fixed. With ``convergent == 1``
    they are ``mu / n``, i.e. mu = n * p is held fixed as n grows.
    """
    # design -> (support of x, fixed probabilities, mu)
    specs = {
        11: ([0.0, 1.0], [0.06, 0.04, 0.06], [10, 5, 10]),
        12: ([0.0, 1.0], [0.06, 0.04, 0.06], [20, 10, 20]),
        2: ([0.0, 1.0, 2.0], [0.06, 0.04, 0.03, 0.06, 0.04, 0.06], [10, 8, 5, 12, 8, 10]),
    }
    if design not in specs:
        raise ValueError(f"Unsupported design {design!r}; expected 11, 12 or 2.")
    support, fixed_prob, mu = specs[design]
    if convergent == 0:
        prob = np.array(fixed_prob)
    elif convergent == 1:
        prob = np.array(mu) / n
    else:
        raise ValueError(f"Unsupported convergent {convergent!r}; expected 0 or 1.")
    return np.array(support), prob


def _peer_expectation(lambda_n, qn, Lx, Lm):
    """Return E[lambda(x_j, m_j) | Sx[ix], Sm[im]] with equal weights omega(ix, jx) = 1/Lx.

    The loop order matches the original code so floating-point results are unchanged.
    """
    expectation = np.zeros((Lx, Lm))
    for ix in range(Lx):
        for im in range(Lm):
            for jx in range(Lx):
                inner_sum = 0
                for jm in range(Lm):
                    inner_sum += lambda_n[jx, jm] * qn[jm, im, ix, jx]
                expectation[ix, im] += (1/Lx) * inner_sum
    return expectation


def _solve_lambda(qn, Sx, Sm, beta, delta, tau, tol=1e-6, maxit=5000):
    """Solve lambda = Sx*beta + Sm*delta + tau*E[lambda | x, m] by fixed-point iteration.

    Iteration starts from all ones and stops when the max absolute change is
    below ``tol``. A message is printed in either outcome.
    """
    Lx, Lm = len(Sx), len(Sm)
    lambda_n = np.ones((Lx, Lm))
    for iteration in range(maxit):
        expectation = _peer_expectation(lambda_n, qn, Lx, Lm)
        lambda_new = Sx[:, np.newaxis] * beta + Sm[np.newaxis, :] * delta + tau * expectation
        err = np.max(np.abs(lambda_n - lambda_new))
        lambda_n = lambda_new
        if err < tol:
            # iteration is 0-based, so iteration + 1 rounds have run.
            print(f"Iteration of strategy function reached tolerance after {iteration + 1} rounds.")
            break
    else:
        print("Maximum number of iterations reached.")
    return lambda_n


def _fit_sample(x, g, m, a, Sx, Sm, x0):
    """Estimate (beta, delta, tau) by Nelder-Mead NLS on one network.

    x, m, a : (n,) covariate, discretized m and choice of each node.
    g : (n, n) adjacency matrix of the same network.
    Returns the ``scipy.optimize.OptimizeResult`` of the fit.
    """
    n = x.shape[0]
    Lx, Lm = len(Sx), len(Sm)
    qn_hat = calculate_qn(x[:, np.newaxis], g[:, :, np.newaxis], m[:, np.newaxis], Sx, Sm)
    lambda_hat = estimate_lambda(m, x, a, Sx, Sm)

    # Support index of each node (Sx and Sm must be sorted ascending).
    Ix = np.searchsorted(Sx, x)
    Im = np.searchsorted(Sm, m)
    # qn_hat[:, Im[i], Ix[i], :] for every node i: (Lm, n, Lx) -> (n, Lm, Lx).
    qn_i = np.transpose(qn_hat[:, Im, Ix, :], (1, 0, 2))
    # Sigma[i] = sum over (jx, jm) of qn_hat(jm | m_i, x_i, jx) * lambda_hat[jx, jm]
    Sigma = np.zeros(n)
    for jx in range(Lx):
        for jm in range(Lm):
            Sigma += qn_i[:, jm, jx] * lambda_hat[jx, jm]

    objective = partial(Gn, Sx=Sx, Sm=Sm, Lx=Lx, Ix=Ix, Im=Im, lam=lambda_hat, Sigma=Sigma)
    return minimize(fun=objective, x0=x0, method='Nelder-Mead',
                    options={'xatol': 1e-3, 'maxiter': 10000, 'fatol': 1e-4})


def run_simulation(n: int, design: int, case: int, S: int = 200, S2: int = 10000, convergent: int = 1, Lm: int = 5):
    """Run one Monte Carlo scenario and save bias/variance/MSE of the NLS estimates.

    Steps:
      0. look up structural parameters (``case``) and network design
         (``design``, ``convergent``);
      1. simulate ``S2`` networks of ``n`` nodes, estimate q_n from them and
         solve the fixed point for lambda_n;
      2. generate choices a = x*beta + m*delta + tau*E[lambda | x, m] + eps in
         every network, then draw ``S`` networks without replacement as samples;
      3. estimate (beta, delta, tau) by Nelder-Mead NLS in each sample.

    Results (bias, var, mse, param_est, flag_mat, funval_mat) are saved to
    ``output/output_n{n}_defn1_design{design}_case{case}.npz`` under the current
    working directory. The ``output`` directory is created if missing.

    Raises
    ------
    ValueError
        If ``case``, ``design`` or ``convergent`` is not a supported value.
    """
    start_time = time.time()

    # --- Part 0: parameters and specification
    beta, delta, tau, d_epsilon = _structural_params(case)
    Sx, prob = _design_spec(design, convergent, n)
    Lx = len(Sx)
    true_params = np.array([beta, delta, tau])

    # --- Part 1: asymptotic moments from S2 simulated networks
    X, G, M, M_raw, Sm = generate_XMG(n, S2, Lm, design, Sx, prob)
    qn = calculate_qn(X, G, M, Sx, Sm)
    lambda_n = _solve_lambda(qn, Sx, Sm, beta, delta, tau)

    # --- Part 2: choices in all S2 networks, then draw S Monte Carlo samples.
    # The order of random draws (networks, epsilon, sample index) is unchanged.
    Epsilon = d_epsilon.rvs(size=(n, S2))
    Ea = _peer_expectation(lambda_n, qn, Lx, Lm)
    A = np.zeros((n, S2))
    for ix in range(Lx):
        for im in range(Lm):
            mask = (X == Sx[ix]) & (M == Sm[im])
            A[mask] = X[mask] * beta + M[mask] * delta + tau * Ea[ix, im] + Epsilon[mask]

    sample_index = np.random.choice(S2, size=S, replace=False)
    x_data = X[:, sample_index]
    m_data = M[:, sample_index]
    a_data = A[:, sample_index]
    g_data = G[:, :, sample_index]

    # --- Part 3: NLS estimation in each sample
    param_est = np.zeros((3, S))
    flag_mat = np.zeros((S,), dtype=bool)
    funval_mat = np.zeros((S,))
    for s in range(S):
        # Starting values offset from the true (beta, delta).
        x0 = np.array([beta - 1, delta + 1, .1])
        result_s = _fit_sample(x_data[:, s], g_data[:, :, s], m_data[:, s], a_data[:, s], Sx, Sm, x0)
        flag_mat[s] = result_s.success
        funval_mat[s] = result_s.fun
        param_est[:, s] = result_s.x

    # --- Save results (np.savez previously failed if ./output did not exist)
    filename = f'output_n{n}_defn1_design{design}_case{case}'
    out_dir = os.path.join(os.getcwd(), 'output')
    os.makedirs(out_dir, exist_ok=True)
    fullname = os.path.join(out_dir, filename)
    bias = np.mean(param_est, axis=1) - true_params
    var = np.var(param_est, axis=1)
    mse = np.mean((param_est - true_params[:, np.newaxis])**2, axis=1)
    np.savez(fullname, bias=bias, var=var, mse=mse, param_est=param_est, flag_mat=flag_mat, funval_mat=funval_mat)

    run_time = time.time() - start_time
    print(f"The code for n={n}, design={design}, case={case} ran in {run_time:.4f} seconds.")

# ----------------------------------------------------------------------
# Outer loop: run every scenario once (manual batch run)
# ----------------------------------------------------------------------

if __name__ == '__main__':
    import traceback

    # Scenarios as (n, design, case). All scenarios share the single global
    # seed set at import, so each one's draws depend on those run before it.
    # Keep the order fixed to reproduce results.
    scenarios = [
        # ( 50, 11, 1), ( 50, 12, 1), ( 50, 2, 1),
        # ( 50, 11, 2), ( 50, 12, 2), ( 50, 2, 2),
        (100, 11, 1), (100, 12, 1), (100, 2, 1),
        (100, 11, 2), (100, 12, 2), (100, 2, 2),
        (200, 11, 1), (200, 12, 1), (200, 2, 1),
        (200, 11, 2), (200, 12, 2), (200, 2, 2),
        (400, 11, 1), (400, 12, 1), (400, 2, 1),
        (400, 11, 2), (400, 12, 2), (400, 2, 2)  ]

    for n_val, design_val, case_val in scenarios:
        label = f"defn = 1, n={n_val}, design={design_val}, case={case_val}"
        print(f"\n--- Running Scenario: {label} ---")
        try:
            run_simulation(n_val, design_val, case_val)
        except Exception as e:
            # Best effort: report the failure and move on to the next scenario,
            # but keep the full stack trace so the failure can be diagnosed.
            print(f"!!! Error in Scenario: {label} !!!")
            print(f"Error details: {e}")
            traceback.print_exc()
            continue
        else:
            print(f"Finished Scenario: {label}")